{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "dotnet_interactive": {
     "language": "fsharp"
    },
    "vscode": {
     "languageId": "polyglot-notebook"
    }
   },
   "outputs": [],
   "source": [
    "#r \"nuget: TorchSharp-cpu\"\n",
    "\n",
    "open TorchSharp\n",
    "open type TorchSharp.torch\n",
    "open type TorchSharp.TensorExtensionMethods\n",
    "open type TorchSharp.torch.distributions"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training with a Learning Rate Scheduler\n",
    "\n",
    "In Tutorial 6, we saw how the optimizers took an argument called the 'learning rate,' but didn't spend much time on it except to say that it could have a great impact on how quickly training would converge toward a solution. In fact, you can choose the learning rate (LR) so poorly that the training doesn't converge at all.\n",
    "\n",
    "If the LR is too small, training will go very slowly, wasting compute resources. If it is too large, training can diverge, ending in numeric overflow or NaNs. Either way, you're in trouble."
   ]
  },
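  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two failure modes concrete, here is a minimal sketch that has nothing to do with TorchSharp: plain gradient descent on $f(x) = x^2$, starting from $x = 1$. The function `descend` and its parameters are hypothetical names invented for this illustration. Each update multiplies $x$ by $(1 - 2 \\cdot \\text{lr})$, so a tiny LR barely moves toward the minimum at 0, a moderate one gets there quickly, and an LR above 1 makes $|x|$ grow with every step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "dotnet_interactive": {
     "language": "fsharp"
    },
    "vscode": {
     "languageId": "polyglot-notebook"
    }
   },
   "outputs": [],
   "source": [
    "// Illustration only: gradient descent on f(x) = x^2, whose gradient is 2x.\n",
    "// `descend`, `lr` and `steps` are made-up names for this sketch.\n",
    "let descend (lr: float) (steps: int) =\n",
    "    let mutable x = 1.0\n",
    "    for _ in 1 .. steps do\n",
    "        x <- x - lr * 2.0 * x   // x <- x - lr * f'(x)\n",
    "    x\n",
    "\n",
    "// Too small, reasonable, and too large.\n",
    "for lr in [0.001; 0.4; 1.1] do\n",
    "    printfn \"lr = %5.3f  ->  x after 50 steps = %g\" lr (descend lr 50)"
   ]
  },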
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To further complicate matters, it turns out that the learning rate shouldn't necessarily be constant. Training can go much better if the learning rate starts out relatively large and gets smaller as you get closer to the end.\n",
    "\n",
    "There's a solution for this, called a Learning Rate Scheduler. An LRS instance has access to the internal state of the optimizer, and can modify the LR as it goes along. There are several algorithms for scheduling, of which TorchSharp currently implements a significant subset."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before demonstrating, let's have a model and a baseline training loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "dotnet_interactive": {
     "language": "fsharp"
    },
    "vscode": {
     "languageId": "polyglot-notebook"
    }
   },
   "outputs": [],
   "source": [
    "type Trivial() as this = \n",
    "    inherit nn.Module<Tensor,Tensor>(\"Trivial\")\n",
    "\n",
    "    let lin1 = nn.Linear(1000L, 100L)\n",
    "    let lin2 = nn.Linear(100L, 10L)\n",
    "\n",
    "    do\n",
    "        this.RegisterComponents()\n",
    "\n",
    "    override _.forward(input) = \n",
    "    \n",
    "        use x = lin1.forward(input)\n",
    "        use y = nn.functional.relu(x)\n",
    "        lin2.forward(y)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To demonstrate how to correctly use an LR scheduler, our training data needs to look more like real training data, that is, it needs to be divided into batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "dotnet_interactive": {
     "language": "fsharp"
    },
    "vscode": {
     "languageId": "polyglot-notebook"
    }
   },
   "outputs": [],
   "source": [
    "let learning_rate = 0.01\n",
    "let model = Trivial()\n",
    "\n",
    "let data = [for i = 1 to 16 do rand(32,1000)]  // Our pretend input data\n",
    "let result = [for i = 1 to 16 do rand(32,10)]  // Our pretend ground truth.\n",
    "\n",
    "let loss x y = nn.functional.mse_loss(x,y)\n",
    "\n",
    "let optimizer = torch.optim.SGD(model.parameters(), learning_rate)\n",
    "\n",
    "for epoch = 1 to 300 do\n",
    "\n",
    "    for idx = 0 to data.Length-1 do\n",
    "        // Compute the loss\n",
    "        let pred = model.forward(data.[idx])\n",
    "        let output = loss pred result.[idx]\n",
    "\n",
    "        // Clear the gradients before doing the back-propagation\n",
    "        model.zero_grad()\n",
    "\n",
    "        // Do back-propagation, which computes all the gradients.\n",
    "        output.backward()\n",
    "\n",
    "        optimizer.step() |> ignore\n",
    "\n",
    "let pred = model.forward(data.[0])\n",
    "(loss pred result.[0]).item<single>()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When I ran this, the loss was down to 0.051 after 3 seconds. (It took longer the first time around.)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## StepLR\n",
    "\n",
    "StepLR multiplies the learning rate by a constant factor (`gamma`) once every `step_size` epochs. The difference it makes to the training loop is that you wrap the optimizer in a scheduler, and then call `step` on the scheduler (once per epoch) as well as on the optimizer (once per batch)."
   ]
  },
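  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the arithmetic, assuming TorchSharp's `StepLR` follows the same rule as its PyTorch counterpart: write $\\eta_0$ for the initial rate, $s$ for the step size and $\\gamma$ for the decay factor (my notation, not the library's). After $k$ calls to `scheduler.step()`, the learning rate should be\n",
    "\n",
    "$$\\eta_k = \\eta_0 \\cdot \\gamma^{\\lfloor k / s \\rfloor}$$\n",
    "\n",
    "With the values used in the next cell ($\\eta_0 = 0.01$, $s = 25$, $\\gamma = 0.95$), the rate stays at $0.01$ for the first 25 epochs, drops to $0.0095$ for the next 25, then to $0.009025$, and so on. After 300 epochs it has been reduced 12 times, to roughly $0.01 \\cdot 0.95^{12} \\approx 0.0054$."
   ]
  },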
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "dotnet_interactive": {
     "language": "fsharp"
    },
    "vscode": {
     "languageId": "polyglot-notebook"
    }
   },
   "outputs": [],
   "source": [
    "let learning_rate = 0.01\n",
    "let model = Trivial()\n",
    "\n",
    "let data = [for i = 1 to 16 do rand(32,1000)]  // Our pretend input data\n",
    "let result = [for i = 1 to 16 do rand(32,10)]  // Our pretend ground truth.\n",
    "\n",
    "let loss x y = nn.functional.mse_loss(x,y)\n",
    "\n",
    "let optimizer = torch.optim.SGD(model.parameters(), learning_rate)\n",
    "// step_size = 25, gamma = 0.95: multiply the LR by 0.95 once every 25 epochs.\n",
    "let scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 25, 0.95, verbose=true)\n",
    "\n",
    "for epoch = 1 to 300 do\n",
    "\n",
    "    for idx = 0 to data.Length-1 do\n",
    "        // Compute the loss\n",
    "        let pred = model.forward(data.[idx])\n",
    "        let output = loss pred result.[idx]\n",
    "\n",
    "        // Clear the gradients before doing the back-propagation\n",
    "        model.zero_grad()\n",
    "\n",
    "        // Do back-propagation, which computes all the gradients.\n",
    "        output.backward()\n",
    "\n",
    "        optimizer.step() |> ignore\n",
    "\n",
    "    scheduler.step() |> ignore\n",
    "\n",
    "let pred = model.forward(data.[0])\n",
    "(loss pred result.[0]).item<single>()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Well, that was underwhelming. The final loss (in my case) was a bit higher than in the baseline run, so that's nothing to get excited about. For this trivial model, using a scheduler isn't going to make a huge difference, and it may not make much of a difference even for complex models. It's very hard to know until you try it, but now you know how to try it out. If you run this trivial example over and over, you will see that the results vary quite a bit, not least because the model weights and the data are freshly randomized on every run. It's simply too simple.\n",
    "\n",
    "Regardless, you can see from the verbose output that the learning rate is adjusted as the epochs proceed."
   ]
  },
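  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want something to compare the verbose output against, the following sketch computes the rate you would expect after a given number of `scheduler.step()` calls. It assumes that StepLR multiplies the rate by `gamma` once every `step_size` calls, as its PyTorch counterpart does. `expectedLR` is a hypothetical helper written for this tutorial, not part of TorchSharp."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "dotnet_interactive": {
     "language": "fsharp"
    },
    "vscode": {
     "languageId": "polyglot-notebook"
    }
   },
   "outputs": [],
   "source": [
    "// Hypothetical helper: expected LR after `steps` calls to scheduler.step(),\n",
    "// assuming the rate is multiplied by `gamma` once every `stepSize` calls.\n",
    "let expectedLR (lr0: float) (stepSize: int) (gamma: float) (steps: int) =\n",
    "    lr0 * (gamma ** float (steps / stepSize))   // integer division counts the decays so far\n",
    "\n",
    "// Same settings as the StepLR cell above: lr0 = 0.01, step_size = 25, gamma = 0.95.\n",
    "for steps in 0 .. 25 .. 300 do\n",
    "    printfn \"after %3d scheduler steps: lr = %.6f\" steps (expectedLR 0.01 25 0.95 steps)"
   ]
  },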
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".NET (F#)",
   "language": "F#",
   "name": ".net-fsharp"
  },
  "language_info": {
   "name": "F#"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
